"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script creates Figure S1, which illustrates the importance of time weighting.


"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import lambertw


N = 8  # size parameter used for the figure (presumably the number of nodes)


def F11_A(t, l=1, N=8):
    """Forward F_11(t) for t <= tstar (before the switch).

    Agrees with ``F11_B`` at t = tstar. ``t`` must be non-zero because the
    expression divides by it.
    """
    relaxed = 1 - np.exp(-8*l*t/3)
    return (9/(32*N*l*t))*relaxed + (N-4)/(4*N**2)


def F12_A(t, l=1, N=8):
    """Forward F_12(t) for t <= tstar (before the switch).

    Agrees with ``F12_B`` at t = tstar. ``t`` must be non-zero because the
    expression divides by it.
    """
    relaxed = np.exp(-8*l*t/3) - 1
    return (3/(32*N*l*t))*relaxed + (N-4)/(4*N**2)



#%% analytic solution

# The closed-form expressions below (F11_B, F12_B, F13_B) are valid for t > tstar.
def F11_B(t, tstar, l=1, N=8):
    """Forward F_11(t) for t > tstar (after the switch).

    Agrees with ``F11_A`` at t = tstar.
    """
    dt = t - tstar
    inner = np.exp(-8*l*tstar/3)*(8*l*dt - 4*np.exp(-4*l*dt) - 5) + 9 + 8*l*t
    return -1/N**2 + (1/(32*N*l*t))*inner

def F12_B(t, tstar, l=1, N=8):
    """Forward F_12(t) for t > tstar (after the switch).

    Agrees with ``F12_A`` at t = tstar.
    """
    dt = t - tstar
    inner = np.exp(-8*l*tstar/3)*(8*l*dt + 4*np.exp(-4*l*dt) - 1) - 3 + 8*l*t
    return -1/N**2 + (1/(32*N*l*t))*inner



def F13_B(t, tstar, l=1, N=8):
    """Forward F_13(t) for t > tstar (after the switch).

    At t = tstar it equals ``F12_A(tstar)``, matching the pre-switch branch
    used by ``F_13``.
    """
    dt = t - tstar
    inner = np.exp(-8*l*tstar/3)*(3 - 8*l*dt) - 3 + 8*l*t
    return -1/N**2 + (1/(32*N*l*t))*inner



def t13(l, N=8):
    """Return (3 / (8 l)) * ln(N / (N - 4)).

    This is the time t at which exp(-8 l t / 3) == (N - 4) / N. It is not
    used elsewhere in the visible part of this script.
    """
    ratio = N/(N-4)
    return 3/(8*l)*np.log(ratio)

#%% Piecewise forward F_12(t) and F_13(t) (switching at tstar)

def F_12(t, tstar, l=1, N=8):
    """Forward F_12(t): ``F12_A`` up to and including ``tstar``, ``F12_B`` after.

    ``t`` must be a scalar, because the branch is chosen with a plain
    comparison.
    """
    return F12_A(t, l, N) if t <= tstar else F12_B(t, tstar, l, N)

def F_13(t, tstar, l=1, N=8):
    """Forward F_13(t): ``F12_A`` up to and including ``tstar``, ``F13_B`` after.

    Before the switch, F_13 coincides with F_12, which is why ``F12_A`` is
    reused here. ``F13_B(tstar, tstar)`` equals ``F12_A(tstar)``, so the curve
    is continuous at the switch. ``t`` must be a scalar.
    """
    return F12_A(t, l, N) if t <= tstar else F13_B(t, tstar, l, N)
    





#%% analytic solution, backward case (tstar = switching time)
def F11_back(t,tstar,l=1,N=8):
    """Analytic backward F_11(t), piecewise around the switching time ``tstar``.

    Parameters
    ----------
    t : float
        Evaluation time. It must be a non-zero scalar, because the expressions
        divide by ``t`` and the branch is chosen with ``if``.
    tstar : float
        Switching time separating the two closed-form expressions.
    l : float
        Rate parameter lambda.
    N : int
        Size parameter (presumably the number of nodes).

    Returns
    -------
    float
        The t <= tstar expression, or the t > tstar expression. The two agree
        at t = tstar.
    """
    if t <= tstar:
        return (1/(8*N*l*t))*(1-np.exp(-4*l*t))+(N-2)/(2*N**2)
    # exp(-4 l t*) * exp(-4 l (2t - 5t*)/3) == exp(-8 l (t - t*)/3). Combining
    # the exponents keeps every exponential <= 1 for t > t*. The uncombined
    # form underflows to 0 times an overflowing inf, giving NaN when l*t* is
    # large (e.g. l=10, t*=20).
    early = np.exp(-4*l*tstar)
    decay = np.exp(-8*l*(t-tstar)/3)
    bracket = early*(2 - 6*decay) - 3*decay + 7 + 8*l*(t+tstar)
    return -1/N**2 + (1/(32*N*l*t))*bracket




def F12_back(t,tstar,l=1,N=8):
    """Analytic backward F_12(t), piecewise around the switching time ``tstar``.

    Parameters
    ----------
    t : float
        Evaluation time. It must be a non-zero scalar, because the expressions
        divide by ``t`` and the branch is chosen with ``if``.
    tstar : float
        Switching time separating the two closed-form expressions.
    l : float
        Rate parameter lambda.
    N : int
        Size parameter (presumably the number of nodes).

    Returns
    -------
    float
        The t <= tstar expression, or the t > tstar expression. The two agree
        at t = tstar.
    """
    if t <= tstar:
        return (1/(8*N*l*t))*(np.exp(-4*l*t)-1)+(N-2)/(2*N**2)
    # exp(-4 l t*) * exp(-4 l (2t - 5t*)/3) == exp(-8 l (t - t*)/3). Combining
    # the exponents keeps every exponential <= 1 for t > t*. The uncombined
    # form underflows to 0 times an overflowing inf, giving NaN when l*t* is
    # large.
    early = np.exp(-4*l*tstar)
    decay = np.exp(-8*l*(t-tstar)/3)
    bracket = early*(-2 + 6*decay) - 3*decay - 1 + 8*l*(t+tstar)
    return -1/N**2 + (1/(32*N*l*t))*bracket



def F13_back(t, tstar, l=1, N=8):
    """Analytic backward F_13(t).

    Constant at -1/N**2 up to and including ``tstar``. After the switch, a
    closed-form expression that starts from the same value at t = tstar.
    ``t`` must be a scalar, and non-zero when t > tstar.
    """
    baseline = -1/N**2
    if t <= tstar:
        return baseline
    dt = t - tstar
    return baseline + (1/(32*N*l*t))*(3*(np.exp(-8*l*dt/3)-1) + 8*l*dt)



def t_hatb3(tstar, l=1,N=8):
    """Closed-form zero crossing of the backward F_13 after the switch.

    Solves F13_back(t) = 0 on its t > tstar branch, i.e.

        3*exp(-8*l*(t - tstar)/3) = 3 + 8*l*tstar - 8*l*t*(N - 4)/N,

    using the principal branch of the Lambert W function.

    Parameters
    ----------
    tstar : float
        Switching time.
    l : float
        Rate parameter lambda.
    N : int
        Size parameter. ``N == 4`` is singular (division by zero).

    Returns
    -------
    float
        The crossing time t.

    Raises
    ------
    ValueError
        If the Lambert-W argument is not greater than -1/e, so the principal
        branch has no real value.
    """
    w_arg = -(N/(N-4))*np.exp(-(3*N+32*l*tstar)/(3*N-12))
    # W_0 is real only for arguments >= -1/e. This is an explicit check rather
    # than ``assert`` so it also runs under ``python -O``. Writing it as
    # ``not (x > y)`` also rejects NaN.
    if not (w_arg > -np.exp(-1)):
        raise ValueError(
            f"no real solution: Lambert-W argument {w_arg} is not > -1/e "
            f"(tstar={tstar}, l={l}, N={N})")
    return (1/(8*l*(N-4)))*(3*N + 8*l*N*tstar + (3*N-12)*lambertw(w_arg).real)

def t_hatb(l, N=8):
    """Nontrivial root t > 0 of the t <= tstar branch of the backward F_12.

    With k = N/(N-2), that branch vanishes when exp(-4 l t) = 1 - 4 l t / k.
    The nonzero solution is t = (k + W0(-k exp(-k))) / (4 l), where W0 is the
    principal Lambert W branch.
    """
    k = N/(N-2)
    return (1/(4*l))*(k + lambertw(-k*np.exp(-k)).real)

#%% figure for paper
import matplotlib
import warnings

# 'alex_paper' is a custom style sheet that is not shipped with matplotlib.
# Fall back to the default style (with a warning) instead of crashing when it
# is not installed.
try:
    plt.style.use('alex_paper')
except OSError:
    warnings.warn("matplotlib style 'alex_paper' not found; using the default style")
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
fig, (ax, ax2) = plt.subplots(2, 1, figsize=(8.5, 8))
lw = 1.5

col1 = '#603D93'
col2 = '#349354'
col3 = '#CF720E'

tstar_f = 2  # switching time for the forward panel (ax)
tstar_b = 1  # switching time for the backward panel (ax2)

lambda_1 = 5
lambda_2 = 0.4
lambda_3 = 0.2

# Start just above 0: the t <= tstar expressions divide by t.
ts = np.linspace(1e-6, 10, 1000)

# (lambda, colour) pairs shared by both panels.
curves = [(lambda_1, col1), (lambda_2, col2), (lambda_3, col3)]

# Forward panel: F_13 dashed, then F_12 solid, one curve per lambda.
for func, ls in ((F_13, '--'), (F_12, '-')):
    for lam, col in curves:
        # raw string: '\l' is not a valid escape; the label text is unchanged
        ax.plot(ts, [func(t, tstar_f, l=lam) for t in ts], ls=ls, lw=lw,
                label=rf'$\lambda={lam}$', color=col)

ax.hlines(0, 0, 10, 'k')

ax.set_xscale('linear')
ax.set_ylabel('$F_{12}^{forw}(t), F_{13}^{forw}(t)$')
ax.set_xlabel('$t$')

shift = 0.03
yshiftp = 0.002
yshiftn = 0.004

# Tick and label marking the switching time.
ax.vlines([tstar_f], -0.001, 0.001, 'k')
ax.annotate(r'$t^\star$', (tstar_f - shift, yshiftp))

ax2.yaxis.set_label_position("right")
ax2.yaxis.tick_right()

# Backward panel: F_13 dashed, then F_12 solid, one curve per lambda.
for func, ls in ((F13_back, '--'), (F12_back, '-')):
    for lam, col in curves:
        ax2.plot(ts, [func(t, tstar=tstar_b, l=lam) for t in ts], ls=ls, lw=lw,
                 label=rf'$\lambda={lam}$', color=col)

ax2.set_ylabel('$F_{12}^{back}(t), F_{13}^{back}(t)$')
ax2.set_xlabel('$t$')

ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")

ax2.hlines(0, 0, 10, 'k')

ax2.vlines([tstar_b], -0.001, 0.001, 'k')
ax2.annotate(r'$t^\star$', (tstar_b + shift, yshiftp))

ax.set_xlim([0, 3])
ax2.set_xlim([3, 0])  # reversed x axis: the backward panel reads right to left

ax.set_ylim([-1/N**2 - 0.001, 1/N**2 + 0.001])
ax2.set_ylim([-0.03, +0.045])

#%% savefig

# plt.savefig('figures/paper_figures/analytic_time_weighting_rev.pdf',)
# plt.savefig('figures/paper_figures/analytic_time_weighting_rev.png',dpi=600)



